1   package org.apache.solr.common.util;
2   
3   import java.lang.invoke.MethodHandles;
4   import java.util.ArrayList;
5   import java.util.Collection;
6   
7   /*
8    * Licensed to the Apache Software Foundation (ASF) under one or more
9    * contributor license agreements.  See the NOTICE file distributed with
10   * this work for additional information regarding copyright ownership.
11   * The ASF licenses this file to You under the Apache License, Version 2.0
12   * (the "License"); you may not use this file except in compliance with
13   * the License.  You may obtain a copy of the License at
14   *
15   *     http://www.apache.org/licenses/LICENSE-2.0
16   *
17   * Unless required by applicable law or agreed to in writing, software
18   * distributed under the License is distributed on an "AS IS" BASIS,
19   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20   * See the License for the specific language governing permissions and
21   * limitations under the License.
22   */
23  
24  import java.util.Enumeration;
25  import java.util.HashMap;
26  import java.util.List;
27  import java.util.Map;
28  import java.util.Set;
29  import java.util.concurrent.BlockingQueue;
30  import java.util.concurrent.Callable;
31  import java.util.concurrent.ConcurrentHashMap;
32  import java.util.concurrent.CopyOnWriteArrayList;
33  import java.util.concurrent.ExecutorService;
34  import java.util.concurrent.LinkedBlockingQueue;
35  import java.util.concurrent.RejectedExecutionHandler;
36  import java.util.concurrent.SynchronousQueue;
37  import java.util.concurrent.ThreadFactory;
38  import java.util.concurrent.ThreadPoolExecutor;
39  import java.util.concurrent.TimeUnit;
40  import java.util.concurrent.atomic.AtomicReference;
41  
42  import org.slf4j.Logger;
43  import org.slf4j.LoggerFactory;
44  import org.slf4j.MDC;
45  
46  
47  public class ExecutorUtil {
48    private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
49  
50    private static volatile List<InheritableThreadLocalProvider> providers = new ArrayList<>();
51  
52    public synchronized static void addThreadLocalProvider(InheritableThreadLocalProvider provider) {
53      for (InheritableThreadLocalProvider p : providers) {//this is to avoid accidental multiple addition of providers in tests
54        if (p.getClass().equals(provider.getClass())) return;
55      }
56      List<InheritableThreadLocalProvider> copy = new ArrayList<>(providers);
57      copy.add(provider);
58      providers = copy;
59    }
60  
61    /** Any class which wants to carry forward the threadlocal values to the threads run
62     * by threadpools must implement this interface and the implementation should be
63     * registered here
64     */
65    public interface InheritableThreadLocalProvider {
66      /**This is invoked in the parent thread which submitted a task.
67       * copy the necessary Objects to the ctx. The object that is passed is same
68       * across all three methods
69       */
70      public void store(AtomicReference<?> ctx);
71  
72      /**This is invoked in the Threadpool thread. set the appropriate values in the threadlocal
73       * of this thread.     */
74      public void set(AtomicReference<?> ctx);
75  
76      /**This method is invoked in the threadpool thread after the execution
77       * clean all the variables set in the set method
78       */
79      public void clean(AtomicReference<?> ctx);
80    }
81  
82    // ** This will interrupt the threads! ** Lucene and Solr do not like this because it can close channels, so only use
83    // this if you know what you are doing - you probably want shutdownAndAwaitTermination.
84    // Marked as Deprecated to discourage use.
85    @Deprecated
86    public static void shutdownWithInterruptAndAwaitTermination(ExecutorService pool) {
87      pool.shutdownNow(); // Cancel currently executing tasks - NOTE: this interrupts!
88      boolean shutdown = false;
89      while (!shutdown) {
90        try {
91          // Wait a while for existing tasks to terminate
92          shutdown = pool.awaitTermination(60, TimeUnit.SECONDS);
93        } catch (InterruptedException ie) {
94          // Preserve interrupt status
95          Thread.currentThread().interrupt();
96        }
97      }
98    }
99    
100   // ** This will interrupt the threads! ** Lucene and Solr do not like this because it can close channels, so only use
101   // this if you know what you are doing - you probably want shutdownAndAwaitTermination.
102   // Marked as Deprecated to discourage use.
103   @Deprecated
104   public static void shutdownAndAwaitTerminationWithInterrupt(ExecutorService pool) {
105     pool.shutdown(); // Disable new tasks from being submitted
106     boolean shutdown = false;
107     boolean interrupted = false;
108     while (!shutdown) {
109       try {
110         // Wait a while for existing tasks to terminate
111         shutdown = pool.awaitTermination(60, TimeUnit.SECONDS);
112       } catch (InterruptedException ie) {
113         // Preserve interrupt status
114         Thread.currentThread().interrupt();
115       }
116       if (!shutdown && !interrupted) {
117         pool.shutdownNow(); // Cancel currently executing tasks - NOTE: this interrupts!
118         interrupted = true;
119       }
120     }
121   }
122 
123   public static void shutdownAndAwaitTermination(ExecutorService pool) {
124     pool.shutdown(); // Disable new tasks from being submitted
125     boolean shutdown = false;
126     while (!shutdown) {
127       try {
128         // Wait a while for existing tasks to terminate
129         shutdown = pool.awaitTermination(60, TimeUnit.SECONDS);
130       } catch (InterruptedException ie) {
131         // Preserve interrupt status
132         Thread.currentThread().interrupt();
133       }
134     }
135   }
136 
137   /**
138    * See {@link java.util.concurrent.Executors#newFixedThreadPool(int, ThreadFactory)}
139    */
140   public static ExecutorService newMDCAwareFixedThreadPool(int nThreads, ThreadFactory threadFactory) {
141     return new MDCAwareThreadPoolExecutor(nThreads, nThreads,
142         0L, TimeUnit.MILLISECONDS,
143         new LinkedBlockingQueue<Runnable>(),
144         threadFactory);
145   }
146 
147   /**
148    * See {@link java.util.concurrent.Executors#newSingleThreadExecutor(ThreadFactory)}
149    */
150   public static ExecutorService newMDCAwareSingleThreadExecutor(ThreadFactory threadFactory) {
151     return new MDCAwareThreadPoolExecutor(1, 1,
152             0L, TimeUnit.MILLISECONDS,
153             new LinkedBlockingQueue<Runnable>(),
154             threadFactory);
155   }
156 
157   /**
158    * See {@link java.util.concurrent.Executors#newCachedThreadPool(ThreadFactory)}
159    */
160   public static ExecutorService newMDCAwareCachedThreadPool(ThreadFactory threadFactory) {
161     return new MDCAwareThreadPoolExecutor(0, Integer.MAX_VALUE,
162         60L, TimeUnit.SECONDS,
163         new SynchronousQueue<Runnable>(),
164         threadFactory);
165   }
166 
167   @SuppressForbidden(reason = "class customizes ThreadPoolExecutor so it can be used instead")
168   public static class MDCAwareThreadPoolExecutor extends ThreadPoolExecutor {
169 
170     private static final int MAX_THREAD_NAME_LEN = 512;
171 
172     public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
173       super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
174     }
175 
176     public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
177       super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
178     }
179 
180     public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) {
181       super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
182     }
183 
184     public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler) {
185       super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
186     }
187 
188     @Override
189     public void execute(final Runnable command) {
190       final Map<String, String> submitterContext = MDC.getCopyOfContextMap();
191       StringBuilder contextString = new StringBuilder();
192       if (submitterContext != null) {
193         Collection<String> values = submitterContext.values();
194 
195         for (String value : values) {
196           contextString.append(value + " ");
197         }
198         if (contextString.length() > 1) {
199           contextString.setLength(contextString.length() - 1);
200         }
201       }
202 
203       String ctxStr = contextString.toString().replace("/", "//");
204       final String submitterContextStr = ctxStr.length() <= MAX_THREAD_NAME_LEN ? ctxStr : ctxStr.substring(0, MAX_THREAD_NAME_LEN);
205       final Exception submitterStackTrace = new Exception("Submitter stack trace");
206       final List<InheritableThreadLocalProvider> providersCopy = providers;
207       final ArrayList<AtomicReference> ctx = providersCopy.isEmpty() ? null : new ArrayList<AtomicReference>(providersCopy.size());
208       if (ctx != null) {
209         for (int i = 0; i < providers.size(); i++) {
210           AtomicReference reference = new AtomicReference();
211           ctx.add(reference);
212           providersCopy.get(i).store(reference);
213         }
214       }
215       super.execute(new Runnable() {
216         @Override
217         public void run() {
218           isServerPool.set(Boolean.TRUE);
219           if (ctx != null) {
220             for (int i = 0; i < providersCopy.size(); i++) providersCopy.get(i).set(ctx.get(i));
221           }
222           Map<String, String> threadContext = MDC.getCopyOfContextMap();
223           final Thread currentThread = Thread.currentThread();
224           final String oldName = currentThread.getName();
225           if (submitterContext != null && !submitterContext.isEmpty()) {
226             MDC.setContextMap(submitterContext);
227             currentThread.setName(oldName + "-processing-" + submitterContextStr);
228           } else {
229             MDC.clear();
230           }
231           try {
232             command.run();
233           } catch (Throwable t) {
234             if (t instanceof OutOfMemoryError) {
235               throw t;
236             }
237             log.error("Uncaught exception {} thrown by thread: {}", t, currentThread.getName(), submitterStackTrace);
238             throw t;
239           } finally {
240             isServerPool.remove();
241             if (threadContext != null && !threadContext.isEmpty()) {
242               MDC.setContextMap(threadContext);
243             } else {
244               MDC.clear();
245             }
246             if (ctx != null) {
247               for (int i = 0; i < providersCopy.size(); i++) providersCopy.get(i).clean(ctx.get(i));
248             }
249             currentThread.setName(oldName);
250           }
251         }
252       });
253     }
254   }
255 
256   private static final ThreadLocal<Boolean> isServerPool = new ThreadLocal<>();
257 
258   public static boolean isSolrServerThread() {
259     return Boolean.TRUE.equals(isServerPool.get());
260   }
261 
262 }